{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from rastervision.core.data import ClassConfig, SemanticSegmentationLabels\n",
    "\n",
    "import albumentations as A\n",
    "\n",
    "from rastervision.pytorch_learner import (\n",
    "    SemanticSegmentationRandomWindowGeoDataset,\n",
    "    SemanticSegmentationSlidingWindowGeoDataset,\n",
    "    SemanticSegmentationVisualizer,\n",
    "    SemanticSegmentationGeoDataConfig,\n",
    "    SemanticSegmentationLearnerConfig,\n",
    "    SolverConfig,\n",
    "    SemanticSegmentationLearner,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the AWS_NO_SIGN_REQUEST environment variable for this kernel process.\n",
    "# NOTE(review): presumably this makes S3 reads unsigned (anonymous) so that the\n",
    "# public s3://spacenet-dataset URIs used below work without AWS credentials --\n",
    "# confirm against the GDAL / Raster Vision docs.\n",
    "os.environ[\"AWS_NO_SIGN_REQUEST\"] = \"YES\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two classes, 'background' and 'building'; 'background' is also declared as\n",
    "# the null class.\n",
    "class_config = ClassConfig(\n",
    "    names=[\"background\", \"building\"],\n",
    "    colors=[\"lightgray\", \"darkred\"],\n",
    "    null_class=\"background\",\n",
    ")\n",
    "\n",
    "# Visualizer built from the same class names and colors; later cells use it to\n",
    "# fetch and plot sample batches.\n",
    "viz = SemanticSegmentationVisualizer(\n",
    "    class_names=class_config.names, class_colors=class_config.colors\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training scene: tile L15-0331E-1257N, 2018_01 mosaic (image .tif plus a\n",
    "# GeoJSON of building labels).\n",
    "train_image_uri = \"s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/images/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13.tif\"\n",
    "train_label_uri = \"s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0331E-1257N_1327_3160_13/labels/global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160_13_Buildings.geojson\"\n",
    "\n",
    "# Validation scene: a different tile (L15-0357E-1223N), same 2018_01 mosaic.\n",
    "# Both scenes live under the dataset's 'train/' prefix.\n",
    "val_image_uri = \"s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/images/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13.tif\"\n",
    "val_label_uri = \"s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/labels/global_monthly_2018_01_mosaic_L15-0357E-1223N_1429_3296_13_Buildings.geojson\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction scene: the same tile as the validation scene (L15-0357E-1223N)\n",
    "# but the 2020_01 mosaic instead of 2018_01.\n",
    "# NOTE(review): pred_label_uri is defined here but not referenced by any other\n",
    "# cell visible in this notebook.\n",
    "pred_image_uri = \"s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/images/global_monthly_2020_01_mosaic_L15-0357E-1223N_1429_3296_13.tif\"\n",
    "pred_label_uri = \"s3://spacenet-dataset/spacenet/SN7_buildings/train/L15-0357E-1223N_1429_3296_13/labels/global_monthly_2020_01_mosaic_L15-0357E-1223N_1429_3296_13_Buildings.geojson\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augmentation pipeline applied to training windows (passed as `transform`\n",
    "# when the training dataset is built):\n",
    "#   1. random flip, then random shift/scale/rotate (geometric);\n",
    "#   2. one color transform chosen from the OneOf group;\n",
    "#   3. coarse dropout of up to 5 holes of at most 32x32 px.\n",
    "data_augmentation_transform = A.Compose(\n",
    "    [\n",
    "        A.Flip(),\n",
    "        A.ShiftScaleRotate(),\n",
    "        A.OneOf(\n",
    "            [\n",
    "                A.HueSaturationValue(hue_shift_limit=10),\n",
    "                A.RGBShift(),\n",
    "                A.ToGray(),\n",
    "                A.ToSepia(),\n",
    "                # Brightness-only jitter. A.RandomBrightness is deprecated (and\n",
    "                # removed in newer albumentations releases); this is the\n",
    "                # equivalent call: same 0.2 brightness limit, contrast untouched.\n",
    "                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0),\n",
    "                A.RandomGamma(),\n",
    "            ]\n",
    "        ),\n",
    "        A.CoarseDropout(max_height=32, max_width=32, max_holes=5),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training dataset: random windows over the training scene, with the\n",
    "# augmentation pipeline as the transform. Label polygons get the 'building'\n",
    "# class id as their default class.\n",
    "# NOTE(review): presumably size_lims=(150, 200) bounds the sampled window side\n",
    "# length, out_size=256 is the size each window is resized to, and\n",
    "# max_windows=400 caps the dataset length -- confirm in the Raster Vision docs.\n",
    "train_ds = SemanticSegmentationRandomWindowGeoDataset.from_uris(\n",
    "    class_config=class_config,\n",
    "    image_uri=train_image_uri,\n",
    "    label_vector_uri=train_label_uri,\n",
    "    label_vector_default_class_id=class_config.get_class_id(\"building\"),\n",
    "    size_lims=(150, 200),\n",
    "    out_size=256,\n",
    "    max_windows=400,\n",
    "    transform=data_augmentation_transform,\n",
    ")\n",
    "\n",
    "# Last expression: display the dataset length as the cell output.\n",
    "len(train_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: fetch a batch from the training dataset (second argument 4,\n",
    "# presumably the batch size) and plot the images with their labels.\n",
    "x, y = viz.get_batch(train_ds, 4)\n",
    "viz.plot_batch(x, y, show=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation dataset: sliding windows over the validation scene, with only a\n",
    "# resize to 256x256 as the transform (no augmentation).\n",
    "# NOTE(review): with size=200 and stride=100, neighbouring windows presumably\n",
    "# overlap by half -- confirm.\n",
    "val_ds = SemanticSegmentationSlidingWindowGeoDataset.from_uris(\n",
    "    class_config=class_config,\n",
    "    image_uri=val_image_uri,\n",
    "    label_vector_uri=val_label_uri,\n",
    "    label_vector_default_class_id=class_config.get_class_id(\"building\"),\n",
    "    size=200,\n",
    "    stride=100,\n",
    "    transform=A.Resize(256, 256),\n",
    ")\n",
    "# Last expression: display the number of validation windows.\n",
    "len(val_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same sanity check as for the training data, on the validation dataset.\n",
    "# Note that x and y are rebound here, replacing the training batch.\n",
    "x, y = viz.get_batch(val_ds, 4)\n",
    "viz.plot_batch(x, y, show=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction dataset: sliding windows over the prediction image only -- no\n",
    "# label URI is passed. Window size, stride and resize match the validation\n",
    "# dataset (200 / 100 / 256x256).\n",
    "pred_ds = SemanticSegmentationSlidingWindowGeoDataset.from_uris(\n",
    "    class_config=class_config,\n",
    "    image_uri=pred_image_uri,\n",
    "    size=200,\n",
    "    stride=100,\n",
    "    transform=A.Resize(256, 256),\n",
    ")\n",
    "# Last expression: display the number of prediction windows.\n",
    "len(pred_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network via torch.hub from the 'AdeelH/pytorch-fpn' repo at ref\n",
    "# '0.3', using its 'make_fpn_resnet' entry point. torch.hub.load fetches and\n",
    "# runs code from that repository, so the pinned ref matters.\n",
    "# num_classes follows class_config; in_channels=3 and out_size=(256, 256) line\n",
    "# up with the 256x256 windows produced by the datasets in this notebook.\n",
    "# NOTE(review): presumably an FPN with a ResNet-18 backbone and pretrained\n",
    "# backbone weights -- confirm against the hub repo's documentation.\n",
    "model = torch.hub.load(\n",
    "    \"AdeelH/pytorch-fpn:0.3\",\n",
    "    \"make_fpn_resnet\",\n",
    "    name=\"resnet18\",\n",
    "    fpn_type=\"panoptic\",\n",
    "    num_classes=len(class_config),\n",
    "    fpn_channels=128,\n",
    "    in_channels=3,\n",
    "    out_size=(256, 256),\n",
    "    pretrained=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data config carrying the class names/colors from class_config. The datasets\n",
    "# themselves are not configured here; they are passed to the learner directly.\n",
    "data_cfg = SemanticSegmentationGeoDataConfig(\n",
    "    class_names=class_config.names,\n",
    "    class_colors=class_config.colors,\n",
    "    num_workers=0,  # increase to use multi-processing\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimization settings: batch size 8, learning rate 3e-2, per-class loss\n",
    "# weights [1.0, 10.0].\n",
    "# NOTE(review): presumably the weights are indexed by class id, i.e. 'building'\n",
    "# (index 1 in class_config.names) is weighted 10x relative to 'background' --\n",
    "# confirm.\n",
    "solver_cfg = SolverConfig(batch_sz=8, lr=3e-2, class_loss_weights=[1.0, 10.0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the data and solver configs into the learner config.\n",
    "learner_cfg = SemanticSegmentationLearnerConfig(data=data_cfg, solver=solver_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the learner from the config, the hub model, and the training and\n",
    "# validation datasets. Later cells read TensorBoard logs and the model bundle\n",
    "# from paths under the same './train-demo/' output directory.\n",
    "learner = SemanticSegmentationLearner(\n",
    "    cfg=learner_cfg,\n",
    "    output_dir=\"./train-demo/\",\n",
    "    model=model,\n",
    "    train_ds=train_ds,\n",
    "    valid_ds=val_ds,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log statistics about the datasets attached to the learner before training.\n",
    "learner.log_data_stats()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the TensorBoard notebook extension; it provides the %tensorboard magic\n",
    "# used in the next cell.\n",
    "%load_ext tensorboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start TensorBoard on the 'tb-logs' directory under the learner's output_dir.\n",
    "# NOTE(review): --bind_all presumably makes TensorBoard listen on all network\n",
    "# interfaces; drop it if the notebook runs on a shared or exposed host.\n",
    "%tensorboard --bind_all --logdir \"./train-demo/tb-logs\" --reload_interval 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial training run: 3 epochs.\n",
    "learner.train(epochs=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for one more epoch on the same learner object.\n",
    "# NOTE(review): presumably this continues from the state left by the previous\n",
    "# train() call rather than starting over -- confirm.\n",
    "learner.train(epochs=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the model's predictions for the 'valid' split and show them inline.\n",
    "learner.plot_predictions(split=\"valid\", show=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the trained model as a model bundle.\n",
    "# NOTE(review): the following cells load './train-demo/model-bundle.zip',\n",
    "# which is presumably the file written here under output_dir -- confirm.\n",
    "learner.save_model_bundle()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload a learner from the saved bundle without passing datasets or\n",
    "# training=True.\n",
    "# NOTE(review): presumably this yields an inference-only learner -- confirm.\n",
    "# `learner` is rebound here and rebound again in the following cell, so the\n",
    "# object created here is discarded as soon as that cell runs.\n",
    "learner = SemanticSegmentationLearner.from_model_bundle(\n",
    "    model_bundle_uri=\"./train-demo/model-bundle.zip\",\n",
    "    output_dir=\"./train-demo/\",\n",
    "    model=model,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload from the same bundle, this time with the training and validation\n",
    "# datasets and training=True, so that train() can be called on the result.\n",
    "learner = SemanticSegmentationLearner.from_model_bundle(\n",
    "    model_bundle_uri=\"./train-demo/model-bundle.zip\",\n",
    "    output_dir=\"./train-demo/\",\n",
    "    model=model,\n",
    "    train_ds=train_ds,\n",
    "    valid_ds=val_ds,\n",
    "    training=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the learner that was reloaded from the bundle for one epoch.\n",
    "learner.train(epochs=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot validation predictions again, now from the reloaded learner.\n",
    "learner.plot_predictions(split=\"valid\", show=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the model over every window of the prediction dataset, returning raw\n",
    "# (raw_out=True) outputs as numpy arrays (numpy_out=True).\n",
    "# out_shape must equal the pixel size of the pred_ds windows (size=200 where\n",
    "# pred_ds is built): the model sees windows resized to 256x256, and each\n",
    "# prediction is later placed into its 200x200 window when the windows are\n",
    "# stitched together. A different shape (previously 325x325) would not line up\n",
    "# with the windows.\n",
    "predictions = learner.predict_dataset(\n",
    "    pred_ds,\n",
    "    raw_out=True,\n",
    "    numpy_out=True,\n",
    "    predict_kw=dict(out_shape=(200, 200)),\n",
    "    progress_bar=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combine the per-window predictions into one labels object covering the\n",
    "# prediction scene's extent, pairing each prediction with its window.\n",
    "# NOTE(review): smooth=True presumably keeps per-class scores (rather than\n",
    "# discrete class ids) and combines overlapping windows -- confirm.\n",
    "pred_labels = SemanticSegmentationLabels.from_predictions(\n",
    "    pred_ds.windows,\n",
    "    predictions,\n",
    "    smooth=True,\n",
    "    extent=pred_ds.scene.extent,\n",
    "    num_classes=len(class_config),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the score array for the labels' full extent.\n",
    "# NOTE(review): presumably shaped (num_classes, height, width) -- confirm.\n",
    "# `scores` is not used by any later cell visible in this notebook.\n",
    "scores = pred_labels.get_score_arr(pred_labels.extent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the predicted labels to the relative location 'predict', using the\n",
    "# prediction scene's CRS transformer and the notebook's class config.\n",
    "# The uri is a plain string literal: the original f-string had no placeholders.\n",
    "pred_labels.save(\n",
    "    uri=\"predict\",\n",
    "    crs_transformer=pred_ds.scene.raster_source.crs_transformer,\n",
    "    class_config=class_config,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
